# library.jl is using LinearAlgebra, NPZ, PyPlot, BenchmarkTools, and Random
include("library.jl")
using Match, Printf, JSON

#= INTRODUCTION
This script runs parallel tempering Monte Carlo using library.jl. The command-line arguments
are designed to be looped over on a cluster, while parameters that do not change for a given
phase diagram (J parameters, temperatures, etc.) are read from a JSON file in the output
folder.
=#

#= COMMANDLINE ARGUMENTS
Usage:
$ julia -t <n_threads> PT_script.jl <path> <J_parameter_index> <ribbon_index> <rounds> [<rounds per state>]

Arguments:
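- n_threads must be specified manually; otherwise julia starts with 1 thread. The option -t
  is the same as --threads. It should match "NThreads" in the parameters, because each
  invocation always runs NThreads realizations.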
- path is the path to the folder where the parameters and output files will be. Important:
  there must be a file at path/params.json; see the specification below.
- J_parameter_index represents the index into the array "JParams" in the parameters.
- ribbon_index is tricky: it may be iterated from 1 upwards, but it represents a set of
  n_threads realizations done at this particular J parameter. So if n_threads were 5, then
  a ribbon_index of 2 would refer to realizations 6 through 10.
- rounds specifies the number of rounds to perform, where a single "round" does N Metropolis
  steps (N is the number of spins) and the parallel tempering swaps. The program always does
  this many rounds, even when it continues from a previous run; in that case it increments
  the values in the prior data's histograms.
- rounds per state (optional) sets how often states are taken; if this were 10, then a state
  would be taken every 10 rounds (starting with the end of round 10, with rounds starting at
  1). If omitted, it defaults to 0.
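
The ribbon_index arithmetic above can be sketched as follows. This is only an illustration;
the helper name realizations is hypothetical and is not part of library.jl.

```julia
# Hypothetical helper: which realization indices does a given ribbon cover?
# Mirrors the rule described above: ribbon r with n threads covers
# realizations (r - 1) * n + 1 through r * n.
function realizations(ribbon_index::Integer, n_threads::Integer)
    first_real = (ribbon_index - 1) * n_threads + 1
    return collect(first_real:(first_real + n_threads - 1))
end

println(realizations(2, 5))  # the example above: realizations 6 through 10
println(realizations(3, 2))
```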
=#

#= JSON SPECIFICATION
There must be a file at <path>/params.json. A JSON file is essentially a nested dictionary
in which the order of the key-value pairs does not matter. Some keys are required by this
script, while the rest are optional. Parameters should not be changed once runs have begun
in that folder. The parameter details are best illustrated with an example. The text in
angle brackets is commentary (comments are not allowed in real JSON):

{
<Required parameters>
	"NThreads": 2, <The number of threads>
	"NSpins": 30,  <The number of spins>
	"JFunction": "SK_J", <The J function; the code runs the function of this name.>
	"temps": [0.19952623, 0.27542287, 0.3801894, 0.52480746, 0.72443596, 1.0, 1.38038426,
              1.90546072, 2.63026799, 3.63078055, 5.01187234], <The different temperature
                ensembles used in the parallel tempering runs>
	"JParams": [0.0, 0.3, 0.6, 0.9, 1.2, 1.5, 1.8], <The J parameters>
<==========================================================================================>
<Optional parameters>
	"JScale": 1.0, <Multiply all J_ij by this value; default 1.0>
	"motional": false, <Whether the degrees of freedom should be angles instead of bits;
                        default false>
	"onDiag": false, <Whether to keep the diagonal elements of J_ij nonzero; default false>
	"useK": false, <Whether to use a K matrix [only implemented in SK_J]; default false>
	"KFrac": 0.0, <The ratio of K's standard deviation to J's; default 0.5>
	"unrotateOverlap": false, <When true, track q_xx and q_yy instead of these rotated 45
                               degrees; default false>
	"histLen": 5, <Number of bins in the histograms; default NSpins + 1>
	"saveInterval": 20000, <Write to disk each time this number of steps have passed;
                            default 1000000>
	"pollBase": 2.0, <The base of the polling intervals; default sqrt(2.0)>
	"fileStem": "sample" <Output files are {fileStem}_dddd_dddd.npz; default "PT_out">
}
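
One way to picture the required/optional split is the sketch below, which works on an
already parsed dictionary. The names required_keys and missing_required are illustrative
assumptions, not functions from this script; the script itself simply indexes the required
keys and uses get for the optional ones.

```julia
# Illustrative only: a plain Dict stands in for the parsed params.json.
params = Dict{String,Any}(
    "NThreads" => 2,
    "NSpins" => 30,
    "JFunction" => "SK_J",
    "temps" => [0.5, 1.0, 2.0],
    "JParams" => [0.0, 0.3],
    "histLen" => 5,
)

# The five required keys from the example above.
required_keys = ["NThreads", "NSpins", "JFunction", "temps", "JParams"]

# Hypothetical check: list the required keys that are absent.
missing_required(p) = [k for k in required_keys if !haskey(p, k)]

# Optional keys fall back to their documented defaults via get.
hist_len = get(params, "histLen", params["NSpins"] + 1)
j_scale = get(params, "JScale", 1.0)

println(missing_required(params))
println((hist_len, j_scale))
```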

The polling intervals are the powers of pollBase that lie between two fixed bounds: 60 and
1 trillion (pollMin and pollMax in main()).
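
As a rough sketch of that rule (not the script's actual code, which relies on geomspace from
library.jl), the powers of pollBase inside the bounds could be listed as follows. The
function name poll_grid and the small upper bound used in the demonstration are assumptions
made for illustration.

```julia
# Illustrative sketch: powers of base that lie within [lower, upper],
# rounded to integers. Floating-point rounding in log can matter when a
# bound is an exact power of the base.
function poll_grid(base::Real, lower::Real, upper::Real)
    lo = ceil(Int, log(base, lower))   # smallest exponent with base^lo >= lower
    hi = floor(Int, log(base, upper))  # largest exponent with base^hi <= upper
    return [round(Int, float(base)^n) for n in lo:hi]
end

# With pollBase = 2 and a small upper bound, for readability:
println(poll_grid(2.0, 60, 2000))
```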

=#

#= OUTPUT FILE FORMAT
The code writes its output files to the same directory as the .json file, with filenames of
the form {fileStem}_{JParamIndex}_{Realization}.npz, where both indices are zero-padded to
four digits. These are uncompressed .npz files that contain the following arrays:
 -'mags', 'overlaps': arrays of shape (histLen, NTemps) that give the histograms of the
   magnetizations and overlaps;
 -'mPollData', 'qPollData': arrays of shape (NTemps, k, nPollTimes) that give the result of
   polling a k-vector from the magnetization and overlap at each temperature and at each
   polling time.
 -'polledTimes': vector of size nPollTimes giving the times of the above measurements.
 -'spins1', 'spins2': arrays of shape (NSpins, NTemps) that give the spin states themselves.
   This is mainly to be used as input for the continue functionality rather than to be read
   as useful output.
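
The filename pattern above can be reproduced with Printf, as in this sketch. The helper name
output_name is hypothetical; the four-digit zero padding follows the format string used in
main().

```julia
using Printf

# Hypothetical helper: build {fileStem}_{JParamIndex}_{Realization}.npz
# with both indices zero-padded to four digits.
output_name(stem, j_index, realization) =
    @sprintf("%s_%04d_%04d.npz", stem, j_index, realization)

println(output_name("sample", 2, 6))
```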
=#

#= EXAMPLE RUN
An example run with the above JSON in Sample/params.json would be:
$ julia -t 2 PT_script.jl Sample 2 3 2000

This has J_parameter_index of 2, so it would use a J parameter of 0.3. The ribbon index of 3
with an NThreads of 2 means that this will run realizations 5 and 6 of this J parameter
(regardless of whether realizations 1 through 4 exist). It will run 2000 rounds, and the
histograms will have size 5 and hold data from the final step. Data arrays holding the mean
of the absolute value of the magnetization and the Binder cumulant of the overlap will be
generated for each temperature and each time that is a power of 2 between 60 and 2000, as well
as the final step. It will be run with the specified 11 temperatures and with 30 spins.

We may check this by loading the output:
$ python
>>> import numpy as np
>>> out = np.load("Sample/sample_0002_0006.npz")
>>> overlaps = out["overlaps"]
>>> print(overlaps.shape)
(5, 11)

This indicates that there are 5 bins per histogram at each of the 11 temperatures.
The poll times are also given, and align with the input poll base of 2:
>>> print(out["polledTimes"])
[  64  128  256  512 1024 2000]

These are indeed the powers of 2 above 60 (and below a trillion) up to the final time,
followed by the final time itself. Finally, if we also wanted to load the parameter values,
Python can do that as well:
>>> import json
>>> with open("Sample/params.json", encoding="utf-8") as paramFile:
...     params = json.load(paramFile)
...
>>> print(params["temps"][1])
0.27542287

This is indeed the second input temperature.

Continue functionality: if we were to run this again, ensuring that Python no longer has a
lock on the .npz files, we could do:

$ julia -t 2 PT_script.jl Sample 2 3 3000

We should now have done 5000 rounds in total (2000 + 3000), and indeed we see this:
$ python
>>> import numpy as np
>>> out = np.load("Sample/sample_0002_0006.npz")
>>> print(out["polledTimes"])
[  64  128  256  512 1024 2000 2048 4096 5000]

The data at time 2000 were not overwritten.
This continue functionality works because the spin states are stored in the output as well:
>>> print(out["spins1"].shape)
(30, 11)

Notice that there are 30 spins and 11 temperatures.
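
The polled times seen across the two runs are consistent with the following sketch, which is
inferred from the output above rather than taken from library.jl. The function name
new_poll_times is hypothetical: given a grid of candidate times, the number of rounds already
done, and the number of new rounds, it keeps the grid points reached in this run and appends
the final time.

```julia
# Hypothetical sketch of which poll times a (possibly continued) run records.
function new_poll_times(grid, rounds_done::Integer, rounds::Integer)
    final_time = rounds_done + rounds
    # Grid points strictly after the previous end and no later than the new end.
    times = [t for t in grid if rounds_done < t <= final_time]
    # Append the final time unless it already coincides with a grid point.
    if isempty(times) || times[end] != final_time
        push!(times, final_time)
    end
    return times
end

grid = [2^n for n in 6:13]                 # powers of 2 from 64 to 8192
println(new_poll_times(grid, 0, 2000))     # first run
println(new_poll_times(grid, 2000, 3000))  # continued run
```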
=#

#= BENCHMARKS
I have not tested this particularly rigorously. My Core i5-9400F was able to run 100 spins,
20 temperatures, and 10^5 rounds in around 46-47 seconds per thread, amounting to roughly
4 x 10^8 operations per second. That seems plenty fast for a 2.90 GHz processor. I used the
default Julia optimization level, which I believe to be O2.
=#

function main()
    # Constants; be careful changing these.
    paramFile = "params.json"
    pollBaseDefault = 2^(1/2)
    fileStemDefault = "PT_out"
    saveIntervalDefault = 1_000_000
    pollMin = 60
    pollMax = 1_000_000_000_000

    # Read the commandline arguments.
    filePath = ARGS[1]
    idxJParam = parse(Int64, ARGS[2])
    idxScriptJReal = parse(Int64, ARGS[3])
    rounds = parse(Int64, ARGS[4])
    roundsPerState = 0

    # The fifth argument (rounds per state) is optional.
    if length(ARGS) > 4
        roundsPerState = parse(Int64, ARGS[5])
    end

    # Read the parameter JSON file.
    params = JSON.parsefile(joinpath(filePath, paramFile))
    NThreads = params["NThreads"]
    N = params["NSpins"]
    JFunction = params["JFunction"]
    betas = 1 ./ params["temps"]
    JParam = params["JParams"][idxJParam]
    # Set optional parameters to the input value, or a default value otherwise.
    JScale = get(params, "JScale", 1.0)
    motional = get(params, "motional", false)
    onDiag = get(params, "onDiag", false)
    useK = get(params, "useK", false)
    KFrac = get(params, "KFrac", 0.5)  # key spelled as in the JSON specification above
    histLen = get(params, "histLen", N + 1)
    saveInterval = get(params, "saveInterval", saveIntervalDefault)
    pollBase = get(params, "pollBase", pollBaseDefault)
    fileStem = get(params, "fileStem", fileStemDefault)
    unrotateOverlapTensor = get(params, "unrotateOverlap", false)

    # Set up the polling times. Currently they are just pollBase^n, restricted to lie
    # within pollMin and pollMax. Relies on the geomspace function in library.jl.
    lo = ceil(Int, log(pollBase, pollMin))
    hi = floor(Int, log(pollBase, pollMax))
    allPollTimes = geomspace(pollBase^lo, pollBase^hi, hi - lo + 1)
    allPollTimes = round.(Int, allPollTimes)

    Threads.@threads for i = 1:NThreads
        # Each thread corresponds to a point in the [JParam, Realization]-space; use the
        # indices to (a) seed the RNG and (b) determine the filename.
        idxJReal = (idxScriptJReal - 1) * NThreads + i
        rng = Xoshiro(convert.(UInt, [idxJParam, idxJReal]))
        J = getfield(Main, Symbol(JFunction))(JParam, N, rng=rng, JScale=JScale,
            onDiag=onDiag, useK=useK, KFrac=KFrac)
        outFile = joinpath(filePath, fileStem * (@sprintf "_%04d_%04d.npz" idxJParam idxJReal))

        # Once we know the filename, we can check whether we are continuing an existing
        # run or starting a new run.
        if isfile(outFile)
            inDict = npzread(outFile)
            spins1 = inDict["spins1"]
            spins2 = inDict["spins2"]
            mags = inDict["mags"]
            overlaps = inDict["overlaps"]
            mPollData = inDict["mPollData"]
            qPollData = inDict["qPollData"]
            polledTimes = inDict["polledTimes"]
        else
            NTemps = length(betas)
            polledTimes = Vector{UInt64}()
            mPollData = zeros(Float64, NTemps, 1, 1)
            qPollData = zeros(Float64, NTemps, 1, 1)
            spins1 = sign.(0.5 .- rand(N, NTemps))
            spins2 = sign.(0.5 .- rand(N, NTemps))
            mags = zeros(UInt64, histLen, NTemps)
            overlaps = zeros(UInt64, histLen, NTemps)
            if motional
                mPollData = zeros(Float64, NTemps, 2, 1)
                qPollData = zeros(Float64, NTemps, 2, 1)
                spins1 = 2 .* π .* rand(N, NTemps)
                spins2 = 2 .* π .* rand(N, NTemps)
                mags = zeros(UInt64, histLen, histLen, NTemps)
                overlaps = zeros(UInt64, histLen, histLen, NTemps)
            end
        end
        if motional
            overlapFunc = overlapTensorRotated
            if unrotateOverlapTensor
                overlapFunc = overlapTensor
            end
            PT_poll_save_mo!(spins1, spins2, betas, J, mags, overlaps, mPollData, qPollData,
                polledTimes, rounds, allPollTimes, saveInterval, outFile, roundsPerState;
                overlapFunc=overlapFunc)
        else
            PT_poll_save!(spins1, spins2, betas, J, mags, overlaps, mPollData, qPollData,
                polledTimes, rounds, allPollTimes, saveInterval, outFile, roundsPerState)
        end
    end
end

main()
